import taichi as ti

from ..shape import Shape
from ..utils import Vec3, Inf, Ray3
from ..config import EPS_SPHERE


class Sphere(Shape):
    """Sphere primitive centred at the origin of its local frame.

    The radius is stored in slot 0 of this shape's ``attribute`` entry in the
    shared ``Shape.shapes`` storage, so Taichi code can read it by index.
    """

    def __init__(self, r=1):
        """Register the shape, then record its radius ``r`` in ``attribute[0]``."""
        super().__init__()
        attributes = Shape.shapes[self.index].attribute
        attributes[0] = r

    @ti.func
    def intersect(self: int, ray: Ray3, closest: float):
        # Ray/sphere intersection in the shape's local frame.
        #
        # `self` is not a Python instance here: it is the integer index of
        # this shape inside `Shape.shapes`, used to fetch the radius.
        #
        # With o = ray.origin and d = ray.direct, solving |o + t*d|^2 = r^2
        # gives a*t^2 - 2*half_b*t + c = 0, where
        #     a = |d|^2,  half_b = -(o . d),  c = |o|^2 - r^2,
        # so t = (half_b -/+ sqrt(half_b^2 - a*c)) / a.
        #
        # The nearer root is preferred. A root is accepted only if it lies in
        # [EPS_SPHERE, closest]; the lower bound presumably guards against
        # self-intersection (TODO confirm the intent of EPS_SPHERE).
        #
        # Returns (t, normal):
        #   - on a hit, normal = ray.at(t) / r;
        #   - on a miss, including a tangent ray whose discriminant is exactly
        #     0, t = Inf and normal = Vec3(0).
        radius = Shape.shapes[self].attribute[0]
        a = ray.direct.norm_sqr()
        half_b = -ray.origin.dot(ray.direct)
        c = ray.origin.norm_sqr() - radius**2
        disc = half_b * half_b - c * a

        t_hit = Inf
        normal = Vec3(0)
        if disc > 0:
            root_disc = disc**0.5
            t_near = (half_b - root_disc) / a
            t_far = (half_b + root_disc) / a
            if EPS_SPHERE <= t_near <= closest:
                t_hit = t_near
            elif EPS_SPHERE <= t_far <= closest:
                t_hit = t_far
            if t_hit < Inf:
                # Hit point relative to the sphere centre, scaled by 1 / radius.
                normal = ray.at(t_hit) / radius
        return t_hit, normal
